package srv

import (
	"context"
	"gorm.io/gorm"
)

var (
	ContextKeyTx = "context_key_tx"
)

// 将事务对象设置到上下文中
func (b *BaseSrv) WithTx(ctx context.Context, db *gorm.DB) context.Context {
	return context.WithValue(ctx, ContextKeyTx, db)
}

// 获取事务对象
func (b *BaseSrv) GetTx(ctx context.Context, defaultDb *gorm.DB) *gorm.DB {
	value := ctx.Value(ContextKeyTx)
	if value != nil {
		db, ok := value.(*gorm.DB)
		if ok {
			return db
		}
	}
	return defaultDb
}

type BaseSrv struct {
}
